之前对Pytorch的索引方式一直有点疑惑,昨天在小伙伴的帮助下对其有了更加深刻的理解。下面对这些进行一下总结。另外,值得注意的是,Pytorch号称直接对接的Numpy,因此下面的索引方法理论上也可以适用于Numpy的索引方式。
Pytorch的tensor索引方式有三种:分别为按照long tensor
、按照bool tensor
和按照byte tensor
。下面分别进行介绍。
首先,说明一下,Pytorch默认打印的tensor只有四位小数,可以使用torch.set_printoptions(precision=8)
多打印几个小数。
long tensor
首先看下面的代码。
1 | import torch |
代码解释:当b
为long tensor
时候,a[b]
的实质为a[b, :]
,也就是取出b
中元素作为a
的行索引,而默认取出所有列。这可以使用下面代码解释:
1 | a[0] |
在这个代码里面,a[0]
与a[0, :]
的输出是一致的,所以也就是a{0]
中的0
作为了a
的行索引,而默认取出所有列。同理,a[[0,1,1],:]
中的[0,1,1]
分别作为了a
的行索引。
bool tensor和byte tensor
对于bool tensor
,首先看下面的代码。
1 | import torch |
代码解释:当b
为bool tensor
时候,b
中每一个位置的bool值表示是否取a
对应位置的值,当为True
的时候表示取出该值,当为False
的时候,表示不取该值。
对于byte tensor
,可以看下面代码:
1 | c = b.byte() |
可以看出来,bool tensor
和byte tensor
作为索引列表时效果是一样的,只是不推荐使用byte tensor
而已。